// Copyright © 2025 Bjango. All rights reserved.

#pragma once

#include "LPCvocoder.hpp"
#include "AudioUtils.hpp"
#include "PeakMeter.hpp"
#import "ParameterAddresses.h"
#import <AudioToolbox/AudioToolbox.h>
#import <algorithm>
#import <vector>
#import <span>
#import <memory>

// File-scope constant; not referenced anywhere in this header.
// NOTE(review): the name suggests a per-update level/meter decay divisor used by metering code
// elsewhere — confirm where (or whether) it is used.
static constexpr float levelDecayFactor = 1.05f;

/// Audio-processing core for the vocoder Audio Unit.
///
/// Owns an LPCVocoderProcessor and a PeakMeter, mirrors every parameter value in member
/// variables so it can be re-applied whenever the processor is (re)created, and periodically
/// reports synth/voice/output peak levels through an optional C callback.
class DSPKernel {
public:
    /// Level-meter callback. Receives the context pointer registered with
    /// setLevelsMeterCallback() and the synth, voice and output peaks read from the PeakMeter.
    typedef void (*MeterCallback)(void* context, float synthLevel, float voiceLevel, float outLevel);

    /// Registers (or clears, when `callback` is nullptr) the callback that calculateLevels()
    /// invokes roughly `meterUpdatesPerSecond` times per second of processed audio.
    void setLevelsMeterCallback(MeterCallback callback, void* context) {
        levelsCallback = callback;
        levelsCallbackContext = context;
    }

    /// Creates the vocoder processor and peak meter for `inSampleRate` and applies every stored
    /// parameter value to them.
    ///
    /// @param inSampleRate         Sample rate handed to LPCVocoderProcessor; also used to size
    ///                             the meter update interval.
    /// @param voiceOnlyOutput      Forwarded to LPCVocoderProcessor::setVoiceOnlyOutput.
    /// @param wetMix               Forwarded to LPCVocoderProcessor::setWetMix.
    /// @param dryMix               Forwarded to LPCVocoderProcessor::setDryMix.
    /// @param compressorRatio      Forwarded, with the remaining compressor arguments, to
    ///                             LPCVocoderProcessor::setCompressorSettings.
    void initialize(double inSampleRate, bool voiceOnlyOutput, float wetMix, float dryMix, float compressorRatio, float compressorThreshold, float compressorAttackMs, float compressorReleaseMs) {
        sampleRate = inSampleRate;

        processor = std::make_unique<LPCVocoderProcessor>(inSampleRate);
        processor->setQuality(quality);
        processor->setWhiteNoiseLevel(whiteNoiseLevelDb);
        processor->setWhiteNoiseThreshold(whiteNoiseThreshold);
        processor->setOutputGain(outputGainDb);
        processor->setVoiceInGain(voiceInGainDb);
        processor->setSynthInGain(synthInGainDb);
        processor->setVoiceMode(voiceMode);
        processor->setFormantShift(formantShiftCents);
        processor->setHighPassFilter(highPassFilter);
        processor->setLowPassFilter(lowPassFilter);
        processor->setDownsampleTo31Khz(crush);
        processor->set16Bit(crush);
        processor->setCompressEnabled(compress);

        // Settings supplied by the caller at initialization; setParameter() never changes these.
        processor->setVoiceOnlyOutput(voiceOnlyOutput);
        processor->setWetMix(wetMix);
        processor->setDryMix(dryMix);
        processor->setCompressorSettings(compressorRatio, compressorThreshold, compressorAttackMs, compressorReleaseMs);

        peakMeter = std::make_unique<PeakMeter>();
        // setParameter() skips the meter while it is null, and re-initializing replaces it, so the
        // stored input gains must be pushed to the fresh meter here.
        peakMeter->setVoiceGain(AudioUtils::decibelToScale(voiceInGainDb));
        peakMeter->setSynthGain(AudioUtils::decibelToScale(synthInGainDb));

        // Derive the meter interval from the real sample rate rather than assuming 48 kHz.
        meterUpdateIntervalSamples = inSampleRate > 0.0
            ? std::max(1, static_cast<int>(inSampleRate / meterUpdatesPerSecond))
            : defaultMeterUpdateIntervalSamples;
        meterSampleCounter = 0;
    }

    /// Releases the vocoder processor. process() falls back to pass-through until initialize()
    /// is called again.
    void deInitialize() {
        processor.reset();
    }

    /// Returns whether processing is currently bypassed.
    bool isBypassed() const {
        return bypassed;
    }

    /// Enables or disables bypass (pass-through) mode.
    void setBypass(bool shouldBypass) {
        bypassed = shouldBypass;
    }

    /// Sets the license state; when false, process() passes audio through unprocessed.
    void setLicensed(bool isLicensed) {
        licensed = isLicensed;
    }

    // MARK: - Parameter Getter / Setter

    /// Stores a parameter value and, when the processor / peak meter exist, forwards it to them.
    /// Boolean parameters are treated as enabled only when the value is exactly 1.
    /// Unknown addresses are ignored.
    void setParameter(AUParameterAddress address, AUValue value) {
        switch (address) {
            case RobotRocketParameterAddress::quality:
                quality = value;
                if (processor) {
                    processor->setQuality(value);
                }
                break;
            case RobotRocketParameterAddress::whiteNoiseLevel:
                whiteNoiseLevelDb = value;
                if (processor) {
                    processor->setWhiteNoiseLevel(value);
                }
                break;
            case RobotRocketParameterAddress::whiteNoiseThreshold:
                whiteNoiseThreshold = value;
                if (processor) {
                    processor->setWhiteNoiseThreshold(value);
                }
                break;
            case RobotRocketParameterAddress::outputLevel:
                outputGainDb = value;
                if (processor) {
                    processor->setOutputGain(value);
                }
                break;
            case RobotRocketParameterAddress::voiceInGain:
                voiceInGainDb = value;
                if (processor) {
                    processor->setVoiceInGain(value);
                }
                if (peakMeter) {
                    peakMeter->setVoiceGain(AudioUtils::decibelToScale(value));
                }
                break;
            case RobotRocketParameterAddress::synthInGain:
                synthInGainDb = value;
                if (processor) {
                    processor->setSynthInGain(value);
                }
                if (peakMeter) {
                    peakMeter->setSynthGain(AudioUtils::decibelToScale(value));
                }
                break;
            case RobotRocketParameterAddress::voiceMode:
                voiceMode = static_cast<int>(value);
                if (processor) {
                    processor->setVoiceMode(voiceMode);
                }
                break;
            case RobotRocketParameterAddress::formantShift:
                formantShiftCents = value;
                if (processor) {
                    processor->setFormantShift(value);
                }
                break;
            case RobotRocketParameterAddress::highPassFilter:
                highPassFilter = (value == 1);
                if (processor) {
                    processor->setHighPassFilter(highPassFilter);
                }
                break;
            case RobotRocketParameterAddress::lowPassFilter:
                lowPassFilter = (value == 1);
                if (processor) {
                    processor->setLowPassFilter(lowPassFilter);
                }
                break;
            case RobotRocketParameterAddress::crush:
                // "Crush" drives both the 31 kHz downsample and the 16-bit reduction together.
                crush = (value == 1);
                if (processor) {
                    processor->setDownsampleTo31Khz(crush);
                    processor->set16Bit(crush);
                }
                break;
            case RobotRocketParameterAddress::compress:
                compress = (value == 1);
                if (processor) {
                    processor->setCompressEnabled(compress);
                }
                break;
            default:
                break;
        }
    }

    /// Largest frame count the host has announced it will render in one call.
    AUAudioFrameCount maximumFramesToRender() const {
        return mMaxFramesToRender;
    }

    /// Records the largest frame count the host will render in one call.
    void setMaximumFramesToRender(const AUAudioFrameCount &maxFrames) {
        mMaxFramesToRender = maxFrames;
    }

    /// Renders one buffer.
    ///
    /// Input 0 is used as the voice signal and input 1 as the synth signal. The result is
    /// written to output 0 (and output 1 when present; for a mono output both vocoder channels
    /// target output 0). Audio passes through unchanged when bypassed, unlicensed, not
    /// initialized, or when fewer than two inputs are supplied.
    void process(std::span<float const*> inputBuffers, std::span<float *> outputBuffers, AUEventSampleTime /*bufferStartTime*/, AUAudioFrameCount frameCount) {
        if (outputBuffers.empty()) {
            return;
        }

        if (bypassed || !licensed || !processor || inputBuffers.size() < 2) {
            passThrough(inputBuffers, outputBuffers, frameCount);
            return;
        }

        float* vocalIn = const_cast<float*>(inputBuffers[0]);
        float* synthIn = const_cast<float*>(inputBuffers[1]);

        // The vocoder always expects a left and right output; for a mono output, hand it the
        // same buffer twice.
        float* outL = outputBuffers[0];
        float* outR = outputBuffers.size() > 1 ? outputBuffers[1] : outputBuffers[0];

        processor->process(synthIn, vocalIn, outL, outR, static_cast<int>(frameCount));

        calculateLevels(synthIn, vocalIn, outL, static_cast<int>(frameCount));
    }

    /// Feeds the buffers to the peak meter and, every `meterUpdateIntervalSamples` frames,
    /// reports the current peaks through the registered callback. Does nothing before
    /// initialize() has created the meter.
    void calculateLevels(float* synthIn, float* voiceIn, float* out, int numFrames) {
        if (!peakMeter) {
            return;
        }

        peakMeter->calculateLevels(synthIn, voiceIn, out, numFrames);

        meterSampleCounter += numFrames;

        // Periodically report levels back to the UI.
        if (meterSampleCounter >= meterUpdateIntervalSamples) {
            if (levelsCallback) {
                levelsCallback(levelsCallbackContext, peakMeter->getSynthPeak(), peakMeter->getVoicePeak(), peakMeter->getOutPeak());
            }

            meterSampleCounter = 0;
        }
    }

    /// Dispatches a single render event; only parameter events are handled.
    void handleOneEvent(AUEventSampleTime now, AURenderEvent const *event) {
        switch (event->head.eventType) {
            case AURenderEventParameter: {
                handleParameterEvent(now, event->parameter);
                break;
            }
            default:
                break;
        }
    }

    /// Applies a parameter event immediately (the event's sample time is not used).
    void handleParameterEvent(AUEventSampleTime /*now*/, AUParameterEvent const& parameterEvent) {
        setParameter(parameterEvent.parameterAddress, parameterEvent.value);
    }

private:
    /// Copies each input channel to the matching output channel and silences output channels
    /// that have no corresponding input, so no output buffer is left with stale contents.
    static void passThrough(std::span<float const*> inputBuffers, std::span<float *> outputBuffers, AUAudioFrameCount frameCount) {
        for (size_t channel = 0; channel < outputBuffers.size(); ++channel) {
            float* out = outputBuffers[channel];
            if (channel < inputBuffers.size()) {
                // In-place buffers already hold the input; skip the redundant self-copy.
                if (inputBuffers[channel] != out) {
                    std::copy_n(inputBuffers[channel], frameCount, out);
                }
            } else {
                std::fill_n(out, frameCount, 0.0f);
            }
        }
    }

    std::unique_ptr<LPCVocoderProcessor> processor;
    std::unique_ptr<PeakMeter> peakMeter;

    double sampleRate = 31250.0;

    // configurable parameters (mirrored so they can be re-applied in initialize())
    float quality = 0.0;
    float whiteNoiseLevelDb = 0.0;
    float whiteNoiseThreshold = 0.0;
    float outputGainDb = 0.0;
    float voiceInGainDb = 0.0;
    float synthInGainDb = 0.0;
    int voiceMode = 0;
    float formantShiftCents = 0.0;
    bool highPassFilter = false;
    bool lowPassFilter = false;
    bool crush = false;
    bool compress = false;

    bool licensed = true;
    bool bypassed = false;
    AUAudioFrameCount mMaxFramesToRender = 1024;

    // volume meter variables
    static constexpr int meterUpdatesPerSecond = 15; // aim for 15 updates a second
    static constexpr int defaultMeterUpdateIntervalSamples = 48000 / meterUpdatesPerSecond;
    int meterSampleCounter = 0;
    int meterUpdateIntervalSamples = defaultMeterUpdateIntervalSamples; // recomputed in initialize()

    MeterCallback levelsCallback = nullptr;
    void* levelsCallbackContext = nullptr;
};
